Validate B/scales/zero_points shape in MatMulNBits::PrePack #29445

Open
apsonawane wants to merge 4 commits into main from asonawane/edge-3

@apsonawane

Contributor

MatMulNBits::PrePack ran at session initialization and called the MLAS pack routines using byte counts derived from the node attributes (N, K, bits, block_size) without ever comparing those attributes to the actual tensor Shape(). A crafted .onnx whose attributes overstate the real B (or scales / zero_points) extent triggered a heap-buffer-overflow READ inside MlasQNBitGemmPackQuantBData / MlasLutGemmPack during OrtApis::CreateSession (no Run() required).

The canonical shape check already lives in matmul_nbits_helper::CheckInputs, but it is invoked only from Compute(), after PrePack has already done the OOB read. By then the original B tensor has been replaced with nullptr in the kernel context, so the Compute-time check never re-validates it.

Fix: at the top of PrePack, after the existing early-return guards and before any tensor.DataRaw() read, validate the incoming initializer's Shape() against the attribute-derived shape:

  • B -> (N, k_blocks, blob_size)
  • scales -> (N * k_blocks) or (N, k_blocks)
  • zero_points -> uint8: (N * zp_blob) or (N, zp_blob); else
    (N * k_blocks) or (N, k_blocks)

A mismatch returns INVALID_ARGUMENT so the session fails to load rather than reading past the buffer.
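
The shape rules listed above can be sketched as a standalone check. This is an illustration only, not the code in the PR: `Shape`, `NBitsAttrs`, `MatchesFlatOr2D`, and `ValidatePrePackShapes` are hypothetical names, and the formulas for `k_blocks` and `blob_size` are assumptions (ceiling divisions over `K / block_size` and `block_size * bits / 8`). Only the `zp_blob` formula is taken from the review discussion.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor shape; the real kernel uses onnxruntime's own types.
using Shape = std::vector<int64_t>;

// Node attributes that the pack routines size their reads from.
struct NBitsAttrs {
  int64_t N, K, bits, block_size;
};

// Assumed derivations of the expected extents (ceiling divisions).
inline int64_t KBlocks(const NBitsAttrs& a) { return (a.K + a.block_size - 1) / a.block_size; }
inline int64_t BlobSize(const NBitsAttrs& a) { return (a.block_size * a.bits + 7) / 8; }
inline int64_t ZpBlob(const NBitsAttrs& a) { return (KBlocks(a) * a.bits + 7) / 8; }

// True when `s` is the flattened 1-D layout [rows * cols] or the 2-D layout [rows, cols].
inline bool MatchesFlatOr2D(const Shape& s, int64_t rows, int64_t cols) {
  return s == Shape{rows * cols} || s == Shape{rows, cols};
}

// Returns an empty string when every shape matches the attributes,
// otherwise a short description of the first mismatch found.
// Pass nullptr for `uint8_zero_points` when the node has no uint8 zero_points input.
inline std::string ValidatePrePackShapes(const NBitsAttrs& a, const Shape& b,
                                         const Shape& scales,
                                         const Shape* uint8_zero_points) {
  if (b != Shape{a.N, KBlocks(a), BlobSize(a)}) return "B shape mismatch";
  if (!MatchesFlatOr2D(scales, a.N, KBlocks(a))) return "scales shape mismatch";
  if (uint8_zero_points != nullptr &&
      !MatchesFlatOr2D(*uint8_zero_points, a.N, ZpBlob(a))) {
    return "zero_points shape mismatch";
  }
  return {};
}
```

The point of running such a check before any data pointer is read is that a failure surfaces as a load-time error instead of an out-of-bounds read; the sketch returns a string where the real kernel would return a status.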

Copilot AI left a comment

Contributor

Pull request overview

This PR hardens the CPU MatMulNBits contrib op against malformed models by adding early shape validation in MatMulNBits<T1>::PrePack() so that session initialization rejects inconsistent initializers before any MLAS packing routine can read past the provided buffers.

Changes:

  • Add attribute-derived initializer shape checks for B, scales, and zero_points at the top of MatMulNBits<T1>::PrePack().
  • Add new unit tests that expect session creation to fail (pre-Compute()) for mismatched initializer shapes, plus a compatibility test for legacy flattened scales/zero_points layouts.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adds new PrePack-time shape validation intended to prevent OOB reads during weight packing at session init. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds tests that exercise PrePack-time rejection for malformed initializer shapes and verifies legacy flattened layouts remain accepted. |

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Outdated

@tianleiwu left a comment

Contributor

Review: Validate B/scales/zero_points shape in MatMulNBits::PrePack

Verdict: LGTM (COMMENT). This correctly closes a real heap-buffer-overflow READ reachable at CreateSession time (no Run() required). The analysis holds up:

  • The validation block is placed at the very top of PrePack, after is_packed = false and before every tensor.DataRaw() / constant-tensor read on all paths (LUT and non-LUT, x64 and ARM64). This is the key correctness property and it is satisfied.
  • Running it ahead of the has_g_idx_ / unquantized-ZP / !MlasIsQNBitGemmAvailable early-returns is the right call — it makes bad-shape models fail consistently even on configs (e.g. Win x86 32-bit) where PrePack would otherwise short-circuit and the B tensor is dropped before Compute()'s CheckInputs runs.
  • Validating the constant scales/zero_points during the B prepack (via TryGetConstantInput) is necessary and correct: the B pack path dereferences those tensors before their own PrePack calls run, so per-tensor validation alone would be too late. The gating (has_zp_arg_ && has_zp_input_) matches the conditions under which the pack routines actually read ZP, so no over- or under-validation.
  • The derived shapes match matmul_nbits_helper::CheckInputs exactly: B (N, k_blocks, blob_size), scales [N*k_blocks]/[N,k_blocks], uint8 ZP [N*zp_blob]/[N,zp_blob] with zp_blob=(k_blocks*bits+7)/8, else [N*k_blocks]/[N,k_blocks]. INVALID_ARGUMENT return makes session load fail cleanly.
  • Test coverage is good: bad B extent, wrong B rank, bad scales, bad uint8 ZP (all expect "MatMulNBits PrePack:" failure before Compute()), plus a positive test that legacy flattened 1D scales/ZP layouts still load and compute correctly (no backward-compat regression).

Two minor, non-blocking observations below.

Nitpick (completeness, not a security gap): For non-uint8 (float) zero_points, CheckInputs additionally rejects a zero_points whose element type differs from scales; this guard checks only the shape. It is not an OOB risk (the LUT float-ZP path dispatches on the ZP's own dtype and enforces its size), and Compute()'s CheckInputs still catches the dtype mismatch — so this is purely a note, no change required.

// the pack routines below dereference tensor.DataRaw(). The MLAS pack routines size their reads
// from the (N, K, bits, block_size) attributes; without this check a crafted model whose
// attributes overstate the real tensor extents would trigger a heap-buffer-overflow READ at
// session initialization. The matching guard in matmul_nbits_helper::CheckInputs is invoked

Contributor

Suggestion (maintainability): the shape math here (k_blocks, blob_size, zp_blob_size_uint8) and the accepted layouts duplicate matmul_nbits_helper::CheckInputs. The cross-reference comment helps, but the two can silently drift if the canonical layout ever changes (e.g. a new packing scheme). Since these are constexpr-style derivations, consider factoring the layout math into a small shared helper in matmul_nbits_helper.h that both this guard and CheckInputs call, so a future layout change updates one place. (Reusing CheckInputs directly would change the error strings the new tests assert on, so a shared derivation helper is the lower-friction option.) Non-blocking.
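
One possible shape for such a shared derivation helper is sketched below. It is a hypothetical illustration, not the current contents of matmul_nbits_helper.h: the names `MatMulNBitsLayout` and `DeriveLayout` are invented here, and the `k_blocks` and `blob_size` formulas are assumed to be ceiling divisions. The `zp_blob_size_uint8` formula follows the one quoted in the review above.

```cpp
#include <cstdint>

// Hypothetical shared layout descriptor that both the PrePack guard and
// CheckInputs could consume, so the layout math lives in one place.
struct MatMulNBitsLayout {
  int64_t k_blocks;            // number of quantization blocks along K (assumed ceil(K / block_size))
  int64_t blob_size;           // bytes per packed block of B (assumed ceil(block_size * bits / 8))
  int64_t zp_blob_size_uint8;  // bytes of packed uint8 zero points per row: (k_blocks * bits + 7) / 8
};

// constexpr so that the derivation can also be verified at compile time (C++14 or later).
constexpr MatMulNBitsLayout DeriveLayout(int64_t K, int64_t bits, int64_t block_size) {
  const int64_t k_blocks = (K + block_size - 1) / block_size;
  return {k_blocks, (block_size * bits + 7) / 8, (k_blocks * bits + 7) / 8};
}

// Compile-time sanity checks for a 4-bit, block_size = 32, K = 64 configuration.
static_assert(DeriveLayout(64, 4, 32).k_blocks == 2, "two blocks of 32 cover K = 64");
static_assert(DeriveLayout(64, 4, 32).blob_size == 16, "32 four-bit values pack into 16 bytes");
static_assert(DeriveLayout(64, 4, 32).zp_blob_size_uint8 == 1, "two 4-bit zero points fit in one byte");
```

With a helper along these lines, each caller keeps its own error strings and only the arithmetic is shared, which is the lower-friction option the suggestion describes.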

"MatMulNBits PrePack: zero_points initializer shape ", s,
" does not match attribute-derived shape [", n * zp_blob_size_uint8, "] or [",
n, ",", zp_blob_size_uint8, "]");
} else {

Contributor

Nitpick: this non-uint8 branch validates the ZP shape but, unlike CheckInputs, does not verify zero_points and scales share the same element type. Not a security gap (no OOB stems from the dtype alone, and Compute()'s CheckInputs still rejects it), so this is just a completeness note — no change required.

3 participants